DeconvGradFilter

计算反卷积(Deconvolution / Transposed Convolution)算子的权重梯度,用于反向传播阶段。 该算子根据输出梯度 dy 与输入特征 x,通过 im2row + GEMM 的方式累加得到卷积核梯度 dw,并支持分组(group)计算。

\[\frac{\partial L}{\partial W} = \sum_{b=0}^{B-1} \text{Im2Row}(dY_b)^{\top} \cdot X_b\]

其中每个 Group 独立计算,最终在 Group 维度上拼接。

输入:
  • dy_data - 输出特征梯度地址,形状为 [batch, out_h, out_w, out_c]

  • x_data - 输入特征地址,形状为 [batch, in_h, in_w, in_c]

  • param - 参数数组地址,用于描述反卷积计算相关参数与工作空间。
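    • param[0] : input_batch(x 的 batch)

    • param[1] : in_h(x 的高度)

    • param[2] : in_w(x 的宽度)

    • param[3] : in_c(x 的通道数)

    • param[4] : batch(dy 的 batch,即 output_batch)

    • param[5] : out_h(dy 的高度)

    • param[6] : out_w(dy 的宽度)

    • param[7] : out_c(dy 的通道数)

    • param[8] : kernel_h

    • param[9] : kernel_w

    • param[10] : stride_h

    • param[11] : stride_w

    • param[12] : pad_u

    • param[13] : pad_l

    • param[14] : dilation_h

    • param[15] : dilation_w

    • param[16] : group

    • param[17] : im2row 工作缓冲区地址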

  • core_mask - 核掩码(仅适用于共享存储版本)。

输出:
  • dw_data - 权重梯度输出地址,布局为 [group, out_c/group * k_h * k_w, in_c/group]

支持平台:

FT78NE MT7004

备注

  • FT78NE 仅支持 fp 类型

  • MT7004 支持 hp, fp 类型

  • 输入 dy 与 x 的数据格式为 NHWC;输出 dw 的布局为 [group, out_c/group * k_h * k_w, in_c/group]

共享存储版本:

void hp_deconv_grad_filter_s(half *dy_data, half *x_data, half *dw_data, long long *param, int core_mask)
void fp_deconv_grad_filter_s(float *dy_data, float *x_data, float *dw_data, long long *param, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <deconvgradfilter.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *x = (float *)0x10010000;
 7    float *dy = (float *)0x10020000;
 8    float *dw = (float *)0x10030000;
 9    float *temp_space = (float *)0x10050000;
10
11    //参数
12    int input_batch = 16;
13    int output_batch = input_batch;
14
15    // Input: 2x2, 1 Channel
16    int input_h = 4;
17    int input_w = 4;
18    int input_channel = 4;
19
20    // Output Grad (dy): 2x2, 1 Channel
21    int output_h = input_h;
22    int output_w = input_w;
23    int output_channel = input_channel;
24
25    // Kernel: 2x2
26    int kernel_h = 4;
27    int kernel_w = 4;
28
29    int stride_h = 1;
30    int stride_w = 1;
31    int pad_u = 0;
32    int pad_l = 0;
33    int dilation_h = 1;
34    int dilation_w = 1;
35    int group = 1;
36
37    srand(seed++);
38    int i;
39    for(i = 0; i < input_batch * input_h * input_w * input_channel; ++i) {
40        x[i] = (float)(rand()%10)/50.0f + 0.1f;
41    }
42    for(i = 0; i < output_batch *output_h * output_w * output_channel; ++i) {
43        dy[i] = (float)(rand()%10)/50.0f + 0.1f;
44    }
45    for(i = 0; i < kernel_h * kernel_w * output_channel * input_channel/group; ++i) {
46        dw[i] = 0.0f;
47    }
48
49    // 1. 设置参数
50    long long params[20];
51    params[0] = input_batch;
52    params[1] = input_h;
53    params[2] = input_w;
54    params[3] = input_channel;
55    params[4] = output_batch;
56    params[5] = output_h;
57    params[6] = output_w;
58    params[7] = output_channel;
59    params[8] = kernel_h;
60    params[9] = kernel_w;
61    params[10] = stride_h;
62    params[11] = stride_w;
63    params[12] = pad_u;
64    params[13] = pad_l;
65    params[14] = dilation_h;
66    params[15] = dilation_w;
67    params[16] = group;
68    params[17] = (long long)temp_space;
69
70    int core_mask = 0b1111;
71    /*性能统计*/
72    fp_deconv_grad_filter_s(dy, x, dw, params, core_mask);
73    return 0;
74}

私有存储版本:

void hp_deconv_grad_filter_p(half *dy_data, half *x_data, half *dw_data, long long *param)
void fp_deconv_grad_filter_p(float *dy_data, float *x_data, float *dw_data, long long *param)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <deconvgradfilter.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *x = (float *)0x10010000;
 7    float *dy = (float *)0x10020000;
 8    float *dw = (float *)0x10030000;
 9    float *temp_space = (float *)0x10050000;
10
11    //参数
12    int input_batch = 16;
13    int output_batch = input_batch;
14
15    // Input: 2x2, 1 Channel
16    int input_h = 4;
17    int input_w = 4;
18    int input_channel = 4;
19
20    // Output Grad (dy): 2x2, 1 Channel
21    int output_h = input_h;
22    int output_w = input_w;
23    int output_channel = input_channel;
24
25    // Kernel: 2x2
26    int kernel_h = 4;
27    int kernel_w = 4;
28
29    int stride_h = 1;
30    int stride_w = 1;
31    int pad_u = 0;
32    int pad_l = 0;
33    int dilation_h = 1;
34    int dilation_w = 1;
35    int group = 1;
36
37    srand(seed++);
38    int i;
39    for(i = 0; i < input_batch * input_h * input_w * input_channel; ++i) {
40        x[i] = (float)(rand()%10)/50.0f + 0.1f;
41    }
42    for(i = 0; i < output_batch *output_h * output_w * output_channel; ++i) {
43        dy[i] = (float)(rand()%10)/50.0f + 0.1f;
44    }
45    for(i = 0; i < kernel_h * kernel_w * output_channel * input_channel/group; ++i) {
46        dw[i] = 0.0f;
47    }
48
49    // 1. 设置参数
50    long long params[20];
51    params[0] = input_batch;
52    params[1] = input_h;
53    params[2] = input_w;
54    params[3] = input_channel;
55    params[4] = output_batch;
56    params[5] = output_h;
57    params[6] = output_w;
58    params[7] = output_channel;
59    params[8] = kernel_h;
60    params[9] = kernel_w;
61    params[10] = stride_h;
62    params[11] = stride_w;
63    params[12] = pad_u;
64    params[13] = pad_l;
65    params[14] = dilation_h;
66    params[15] = dilation_w;
67    params[16] = group;
68    params[17] = (long long)temp_space;
69
70    fp_deconv_grad_filter_p(dy, x, dw, params);
71    return 0;
72}